CPU Optimizations for FP8 #2559
base: main
Conversation
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
/te-ci L1 pytorch
Greptile Summary
This PR implements CPU optimizations for FP8 operations by reducing attribute lookup overhead and caching expensive computations.
The changes are well-structured and follow established patterns in the codebase. Previous review threads have comprehensively covered potential memory management concerns; the current implementation properly handles reference counting with Py_DECREF.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant App as Application
participant Linear as Linear Module
participant QT as QuantizedTensor
participant Quantizer as C++ Quantizer
participant PyAPI as Python C API
Note over App,PyAPI: Forward Pass with CPU Optimizations
App->>Linear: forward(inp)
Linear->>Linear: Cache inp.requires_grad, weight.requires_grad
Linear->>QT: Access dtype/requires_grad
QT-->>Linear: Return cached _dtype/_requires_grad
Linear->>Quantizer: create_tensor()
Quantizer->>Quantizer: Cache nvte_is_non_tn_fp8_gemm_supported()
Quantizer->>PyAPI: PyDict_New(), PyTuple_New()
PyAPI-->>Quantizer: kwargs, args
Quantizer->>PyAPI: PyObject_Call(Float8TensorClass)
PyAPI-->>Quantizer: Float8Tensor instance
Quantizer->>PyAPI: Py_DECREF(kwargs), Py_DECREF(args)
Quantizer-->>Linear: (TensorWrapper, py::object)
Linear->>QT: Access shape/is_cuda
QT-->>Linear: Return from cached _data/_transpose
Linear-->>App: output tensor
Additional Comments (3)
- transformer_engine/pytorch/csrc/util.cpp, lines 18-20 (link) logic: Critical logical error: `||` should be `&&`. This condition is always true, since a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs.
- transformer_engine/pytorch/quantized_tensor.py, lines 373-393 (link) style: commented-out code for the `requires_grad` caching optimization; consider removing the dead code entirely. Is this code planned to be implemented later, or should it be removed? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- transformer_engine/pytorch/module/linear.py, line 484 (link) logic: Logical error: this condition should use OR (||), not AND (&&). The original logic checked whether ANY tensor requires gradients for FP8 handling, but this now only activates when ALL three require gradients, including bias, which may be None. Should the FP8 condition check whether any tensor requires gradients (OR logic) rather than all tensors (AND logic)? (A minimal sketch of the OR-based check follows this list.)
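For reference, a minimal runnable sketch of the OR-based check the last comment asks about (the helper name and parameter names are hypothetical, not the actual linear.py code):

def fp8_state_needed(fp8_enabled: bool, inp_requires_grad: bool,
                     weight_requires_grad: bool, bias_requires_grad: bool) -> bool:
    # FP8 bookkeeping should run when FP8 is enabled and ANY tensor needs
    # gradients, not only when ALL of them do (bias may be None or frozen).
    return fp8_enabled and (
        inp_requires_grad or weight_requires_grad or bias_requires_grad
    )

# Example: bias does not require grad, but FP8 state must still be saved.
assert fp8_state_needed(True, True, True, False) is True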
10 files reviewed, 3 comments
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:
- Caches `requires_grad`, `dtype`, `shape`, and `is_cuda` attribute accesses to avoid expensive PyObject lookups on custom tensors (a rough sketch of this pattern follows the list below)
- Reorders attribute checks in `get_tensor_device()` to prioritize internal quantized tensor attributes
- Makes `num_devices` static in `nvte_is_non_tn_fp8_gemm_supported()` to cache the device count
- Stores GEMM support check results in local variables to avoid redundant function calls
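As a rough sketch of the caching pattern in the first bullet (the class and member names below are illustrative, not the exact TransformerEngine implementation):

import torch

class _CachedAttrTensor(torch.Tensor):
    """Toy tensor subclass that caches attributes to avoid repeated PyObject lookups."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, dtype: torch.dtype):
        instance = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=dtype, device=data.device
        )
        instance._data = data
        # Cache once at construction so hot-path reads stay cheap.
        instance._dtype = dtype
        return instance

    @property
    def dtype(self) -> torch.dtype:
        # Served from the cached member instead of the default tensor lookup.
        return self._dtype

    @property
    def shape(self) -> torch.Size:
        return self._data.shape

    @property
    def is_cuda(self) -> bool:
        return self._data.is_cuda

The same idea extends to requires_grad, with the caveats raised in the review comments below about initializing the cached value in __new__.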
Critical Issues Found:
- Variable redeclaration error in `cublaslt_gemm.cu` (line 224) will prevent compilation
- Logic bug in `linear.py` (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad
Confidence Score: 0/5
- This PR cannot be merged due to compilation error and critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
- Pay close attention to `transformer_engine/common/gemm/cublaslt_gemm.cu` (compilation error) and `transformer_engine/pytorch/module/linear.py` (logic bug)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant Quantizer as Quantizer/QuantizedTensor
participant GEMM as GEMM Operations
participant CPP as C++ Extensions
Note over Linear,CPP: Performance Optimization Flow
User->>Linear: forward(input, weight, bias)
Note over Linear: Cache requires_grad checks
Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
Linear->>Quantizer: Check if quantized tensor
alt QuantizedTensor
Note over Quantizer: Use cached dtype property
Quantizer->>Quantizer: return self._dtype
Note over Quantizer: Use cached shape/is_cuda
Quantizer->>Quantizer: return self._data.shape
else Regular Tensor
Quantizer->>Linear: Standard attribute access
end
Linear->>CPP: get_tensor_device(tensor)
Note over CPP: Reordered attribute checks
CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
CPP-->>Linear: device_index
Linear->>GEMM: Configure GEMM parameters
Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
Note over CPP: Static num_devices cached
CPP-->>GEMM: support_flag
GEMM->>GEMM: Store in local variable
GEMM->>GEMM: Execute optimized GEMM
GEMM-->>Linear: output
Note over Linear: FP8 State Management
alt FP8 enabled and requires_grad check
Linear->>Linear: Update FP8 tensors<br/>based on cached flags
end
Linear-->>User: output
Additional Comments (2)
Signed-off-by: Varun Thumbe <[email protected]>
/te-ci L1 pytorch
/te-ci L1 pytorch
8 files reviewed, 8 comments
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.
The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.
This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.
This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
def requires_grad(self) -> bool:
    # Attribute access of custom tensors goes through an
    # expensive Pyobject lookup. Since requires_grad is set during
    # initialization and may be updated, we cache it in a member variable.
    return self._requires_grad
Critical bug: The requires_grad property getter accesses self._requires_grad at line 378, but this member variable is never initialized in the __new__ method (lines 353-364). This will cause an AttributeError when the property is first accessed.
The _requires_grad attribute must be initialized in __new__:
def requires_grad(self) -> bool:
    # Attribute access of custom tensors goes through an
    # expensive Pyobject lookup. Since requires_grad is set during
    # initialization and may be updated, we cache it in a member variable.
    return self._requires_grad
@property
def requires_grad(self) -> bool:
    # Attribute access of custom tensors goes through an
    # expensive Pyobject lookup. Since requires_grad is set during
    # initialization and may be updated, we cache it in a member variable.
    # Fallback to parent if not cached yet
    if not hasattr(self, '_requires_grad'):
        self._requires_grad = torch.Tensor.requires_grad.fget(self)
    return self._requires_grad
instance._dtype = dtype
return instance
Missing initialization: The _requires_grad member variable needs to be initialized here to support the cached property added later in the PR. Without this, accessing the requires_grad property will fail.
instance._dtype = dtype
return instance
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;
Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.
The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:
bool is_extension_initialized = false;
bool is_extension_initialized = false;
std::once_flag extension_init_flag;
And in init_extension():
void init_extension() {
std::call_once(extension_init_flag, []() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
});
}
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.
While dynamic device changes are rare, consider adding a bounds check:
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");
4 files reviewed, 4 comments
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
  scale_inv = at::reciprocal(scale);
}
Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.
The fix is to move line 146 after the initialization block:
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
  scale_inv = at::reciprocal(scale);
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
  scale_inv = at::reciprocal(scale);
}
py::object scale_inv_py = py::cast(*scale_inv);
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs);
Py_DECREF(args);
instance._dtype = dtype
return instance
Critical bug: _dtype is cached in __new__ but _requires_grad is not initialized. The requires_grad property (line 378) tries to return self._requires_grad, which will raise AttributeError when accessed before the setter is called. This breaks the optimization for caching requires_grad.
instance._dtype = dtype
return instance
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <[email protected]>
13 files reviewed, 13 comments
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Move scale_inv_py creation after initialization (line 150)
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
[P0] Memory leak: PyTuple_New(0) is not decremented (same issue as line 166).
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs);
Py_DECREF(args);
ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
[P0] Logic error: The original requires_grad(inp, weight, bias) returns True if ANY tensor requires grad. This change requires ALL THREE to be True with and, breaking FP8 state management when bias is None or doesn't require grad.
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyObject* result =
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyObject* result =
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
/te-ci L1 pytorch
11 files reviewed, 2 comments
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
style: Unlike the shape property above which has a defensive check (raise RuntimeError("Both data and transpose are None")), is_cuda will raise AttributeError if both _data and _transpose are None. Consider adding the same defensive check for consistency:
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    elif self._transpose is not None:
        return self._transpose.is_cuda
    else:
        raise RuntimeError("Both data and transpose are None")
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
style: Same consistency issue as Float8Tensor: is_cuda will raise AttributeError if both _rowwise_data and _columnwise_data are None, and there is no defensive check here (nor in shape, which should also have one). Consider adding error handling for this edge case.
Signed-off-by: Varun Thumbe <[email protected]>
10 files reviewed, 2 comments
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
style: Missing defensive check for is_cuda property. The shape property above correctly raises RuntimeError when both _data and _transpose are None, but is_cuda will raise AttributeError in the same scenario.
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    elif self._transpose is not None:
        return self._transpose.is_cuda
    raise RuntimeError("Both data and transpose are None")
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )
@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda
    )
style: Both shape and is_cuda properties lack defensive checks for when both _rowwise_data and _columnwise_data are None. Consider adding error handling for consistency with Float8Tensor:
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )
@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda
    )
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._rowwise_data is not None:
        return self._rowwise_data.shape
    elif self._columnwise_data is not None:
        return self._columnwise_data.shape
    raise RuntimeError("Both rowwise_data and columnwise_data are None")
@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._rowwise_data is not None:
        return self._rowwise_data.is_cuda
    elif self._columnwise_data is not None:
        return self._columnwise_data.is_cuda
    raise RuntimeError("Both rowwise_data and columnwise_data are None")
…nd transpose are None Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
11 files reviewed, 1 comment
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
logic: The dtype setter unconditionally emits a warning when dtype is set. However, this setter is called in normal operations like Float8Tensor._set_data (line 1004 of float8_tensor.py) and MXFP8Tensor._set_data when copying tensor attributes. This will cause spurious warnings during legitimate tensor operations.
Consider only warning when the value actually changes:
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    if hasattr(self, "_dtype") and self._dtype != value:
        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
    self._dtype = value
11 files reviewed, 1 comment
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
logic: The warning is emitted unconditionally, but this setter is called in Float8Tensor._set_data, MXFP8Tensor._set_data, and NVFP4Tensor._set_data during normal copy operations when any tensor attribute differs (size, stride, device, etc.), not just dtype. This will cause spurious warnings.
Consider only warning when the value actually changes:
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    if hasattr(self, "_dtype") and self._dtype != value:
        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
    self._dtype = value
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
11 files reviewed, 1 comment
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
logic: The dtype setter unconditionally emits a warning, but this setter is called in normal _set_data operations (e.g., float8_tensor.py:1004, mxfp8_tensor.py:789, nvfp4_tensor.py:693) during legitimate tensor copy operations. This will cause spurious warnings every time tensor data is copied.
Consider only warning when the value actually changes:
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    if hasattr(self, "_dtype") and self._dtype != value:
        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
    self._dtype = value
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
12 files reviewed, 3 comments
static std::unordered_map<std::string, void *> symbol_cache;
// Check if symbol is already cached
auto it = symbol_cache.find(symbol);
FuncT *func;
if (it != symbol_cache.end()) {
  func = reinterpret_cast<FuncT *>(it->second);
} else {
  // Symbol not in cache, look it up and cache the result
  void *ptr = get_symbol(symbol);
  symbol_cache[symbol] = ptr;
  func = reinterpret_cast<FuncT *>(ptr);
logic: Thread safety issue: The static unordered_map is accessed without synchronization. If call() is invoked concurrently from multiple threads with different symbols, concurrent reads/writes to the map can cause data races and undefined behavior.
Consider using a mutex or std::shared_mutex for thread-safe access:
static std::unordered_map<std::string, void *> symbol_cache;
static std::mutex cache_mutex;
std::lock_guard<std::mutex> lock(cache_mutex);
auto it = symbol_cache.find(symbol);
// ...
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr());
logic: Memory leak: inc_ref() increments the reference count, but PyDict_SetItemString also increments it. When the kwargs dict is destroyed via Py_DECREF, the objects inside get one decref, leaving them with an extra reference that's never released.
Fix by storing the py::object in a local variable:
py::object fp8_dtype_obj = py::cast(this->dtype);
PyDict_SetItemString(kwargs, "fp8_dtype", fp8_dtype_obj.ptr());
This pattern repeats throughout this file for all inc_ref() calls.
| """Set dtype property""" | ||
| # Update the cached value | ||
| self._dtype = value | ||
| warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.") |
logic: Spurious warning: The dtype setter unconditionally emits a warning, but it's called from _set_data() in Float8Tensor, MXFP8Tensor, and NVFP4Tensor during legitimate tensor copy operations. This will produce warnings in normal use.
Consider only warning when the value actually changes:
warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
if hasattr(self, "_dtype") and self._dtype != value:
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
self._dtype = value
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: vthumbe1503 <[email protected]>
for more information, see https://pre-commit.ci
12 files reviewed, 1 comment
symbol_cache[symbol] = ptr;
});
return (*reinterpret_cast<FuncT *>(symbol_cache[symbol]))(args...);
logic: Data race: symbol_cache[symbol] write at line 57 and read at line 60 are not properly synchronized. While call_once protects this symbol's initialization, a concurrent thread calling call() with a different symbol can modify the unordered_map (line 57) while this thread reads from it (line 60). This is undefined behavior.
Fix by holding the mutex during the cache read:
std::call_once(*flag_ptr, [&]() {
void *ptr = get_symbol(symbol);
std::lock_guard<std::mutex> lock(init_mutex);
symbol_cache[symbol] = ptr;
});
FuncT* func;
{
std::lock_guard<std::mutex> lock(init_mutex);
func = reinterpret_cast<FuncT *>(symbol_cache[symbol]);
}
return (*func)(args...);
Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
12 files reviewed, 1 comment
return self._data.shape
if self._transpose is not None:
    transpose_shape = self._transpose.shape
    return tuple(transpose_shape[1:]) + (transpose_shape[0],)
style: Returns tuple instead of torch.Size when computed from transpose, while self._data.shape returns torch.Size. Consider wrapping with torch.Size() for type consistency:
return tuple(transpose_shape[1:]) + (transpose_shape[0],)
return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: